import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt

def expected_cross_entropy(outputs: torch.Tensor, targets: torch.Tensor, dim: int = 1):
    """
    compute expected cross entropy for baselines Max Entropy

    Both arguments are raw scores (before softmax). ``targets`` is turned into
    probabilities with softmax and ``outputs`` into log-probabilities with
    log-softmax, both along ``dim``. The result is
    ``-(softmax(targets) * log_softmax(outputs)).sum(dim)``.

    :param: outputs (torch tensor): output before softmax
    :param: targets (torch tensor): target before softmax
    :param: dim (int): axis holding the scores to normalize and sum over;
        defaults to 1, the axis the original implementation hard-coded
    :return: (torch tensor) loss with ``dim`` reduced away
    """
    target_probs = nn.functional.softmax(targets, dim=dim)
    # log_softmax is computed directly rather than log(softmax(...)) for
    # numerical stability
    output_log_probs = nn.functional.log_softmax(outputs, dim=dim)
    return -(target_probs * output_log_probs).sum(dim=dim)

def output_softmax(output: torch.Tensor, dim: int = 1):
    """
    compute model prediction probability

    :param: output (torch tensor): output before softmax (raw scores)
    :param: dim (int): axis to normalize over; defaults to 1, the axis the
        original implementation hard-coded
    :return: (torch tensor) probabilities with the same shape as ``output``,
        summing to 1 along ``dim``
    """
    return nn.functional.softmax(output, dim=dim)

def soft_cross_entropy(outputs: torch.Tensor, targets: torch.Tensor, dim: int = 0):
    """
    compute expected loss given the pseudo label

    Computes ``-(targets * log(outputs)).sum(dim)``. A term whose target
    probability is exactly 0 contributes 0, even when the matching output
    probability is also 0. A naive ``targets * torch.log(outputs)`` would
    give ``0 * -inf = NaN`` there and poison the whole sum.

    :param: outputs (torch tensor): output after softmax
    :param: targets (torch tensor): target after softmax
    :param: dim (int): axis to sum over. Defaults to 0 to preserve the
        original behavior. NOTE(review): for batched (batch, classes) input,
        dim=0 sums over the first axis, unlike expected_cross_entropy, which
        sums over dim=1. Confirm which axis callers intend.
    :return: (torch tensor) loss with ``dim`` reduced away
    """
    # xlogy(t, o) == t * log(o), except it is defined as 0 wherever t == 0
    loss = -torch.xlogy(targets, outputs).sum(dim=dim)
    return loss

        